{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.layers.core import Dense\n",
    "import collections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words):\n",
    "    count = [['GO', 0], ['PAD', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    count.extend(collections.Counter(words).most_common(n_words - 1))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('from.txt', 'r') as fopen:\n",
    "    text_from = fopen.read().lower().split('\\n')\n",
    "with open('to.txt', 'r') as fopen:\n",
    "    text_to = fopen.read().lower().split('\\n')\n",
    "print('len from: %d, len to: %d'%(len(text_from), len(text_to)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "concat_from = ' '.join(text_from).split()\n",
    "vocabulary_size_from = len(list(set(concat_from)))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "print('Most common words', count_from[4:10])\n",
    "print('Sample data', data_from[:10], [rev_dictionary_from[i] for i in data_from[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "concat_to = ' '.join(text_to).split()\n",
    "vocabulary_size_to = len(list(set(concat_to)))\n",
    "data_to, count_to, dictionary_to, rev_dictionary_to = build_dataset(concat_to, vocabulary_size_to)\n",
    "print('vocab to size: %d'%(vocabulary_size_to))\n",
    "print('Most common words', count_to[4:10])\n",
    "print('Sample data', data_to[:10], [rev_dictionary_to[i] for i in data_to[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = dictionary_from['GO']\n",
    "PAD = dictionary_from['PAD']\n",
    "EOS = dictionary_from['EOS']\n",
    "UNK = dictionary_from['UNK']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Chatbot:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, batch_size,\n",
    "                 from_dict_size, to_dict_size, grad_clip=5.0):\n",
    "        self.size_layer = size_layer\n",
    "        self.num_layers = num_layers\n",
    "        self.embedded_size = embedded_size\n",
    "        self.grad_clip = grad_clip\n",
    "        self.from_dict_size = from_dict_size\n",
    "        self.to_dict_size = to_dict_size\n",
    "        self.batch_size = batch_size\n",
    "        self.model = tf.estimator.Estimator(self.model_fn)\n",
    "        \n",
    "    def lstm_cell(self, reuse=False):\n",
    "        return tf.nn.rnn_cell.LSTMCell(self.size_layer, reuse=reuse)\n",
    "    \n",
    "    def seq2seq(self, x_dict, reuse):\n",
    "        x = x_dict['x']\n",
    "        x_seq_len = x_dict['x_len']\n",
    "        with tf.variable_scope('encoder', reuse=reuse):\n",
    "            encoder_embedding = tf.get_variable('encoder_embedding') if reuse else tf.get_variable('encoder_embedding', \n",
    "                                                                                                   [self.from_dict_size, self.embedded_size], \n",
    "                                                                                                   tf.float32, tf.random_uniform_initializer(-1.0, 1.0))\n",
    "            _, encoder_state = tf.nn.dynamic_rnn(\n",
    "                cell = tf.nn.rnn_cell.MultiRNNCell([self.lstm_cell() for _ in range(self.num_layers)]), \n",
    "                inputs = tf.nn.embedding_lookup(encoder_embedding, x),\n",
    "                sequence_length = x_seq_len,\n",
    "                dtype = tf.float32)\n",
    "            encoder_state = tuple(encoder_state[-1] for _ in range(self.num_layers))\n",
    "        if not reuse:\n",
    "            y = x_dict['y']\n",
    "            y_seq_len = x_dict['y_len']\n",
    "            with tf.variable_scope('decoder', reuse=reuse):\n",
    "                decoder_embedding = tf.get_variable(\n",
    "                    'decoder_embedding', [self.to_dict_size, self.embedded_size], tf.float32,\n",
    "                    tf.random_uniform_initializer(-1.0, 1.0))\n",
    "                helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                    inputs = tf.nn.embedding_lookup(decoder_embedding, y),\n",
    "                    sequence_length = y_seq_len,\n",
    "                    time_major = False)\n",
    "                decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                    cell = tf.nn.rnn_cell.MultiRNNCell([self.lstm_cell() for _ in range(self.num_layers)]),\n",
    "                    helper = helper,\n",
    "                    initial_state = encoder_state,\n",
    "                    output_layer = tf.layers.Dense(self.to_dict_size))\n",
    "                decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                    decoder = decoder,\n",
    "                    impute_finished = True,\n",
    "                    maximum_iterations = tf.reduce_max(y_seq_len))\n",
    "                return decoder_output.rnn_output\n",
    "        else:\n",
    "            with tf.variable_scope('decoder', reuse=reuse):\n",
    "                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                    embedding = tf.get_variable('decoder_embedding'),\n",
    "                    start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [tf.shape(x)[0]]),\n",
    "                    end_token = EOS)\n",
    "                decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                    cell = tf.nn.rnn_cell.MultiRNNCell(\n",
    "                        [self.lstm_cell(reuse=True) for _ in range(self.num_layers)]),\n",
    "                    helper = helper,\n",
    "                    initial_state = encoder_state,\n",
    "                    output_layer = tf.layers.Dense(self.to_dict_size, _reuse=reuse))\n",
    "                decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                    decoder = decoder,\n",
    "                    impute_finished = True,\n",
    "                    maximum_iterations = 2 * tf.reduce_max(x_seq_len))\n",
    "                return decoder_output.sample_id\n",
    "            \n",
    "    def model_fn(self, features, labels, mode):\n",
    "        logits = self.seq2seq(features, reuse=False)\n",
    "        predictions = self.seq2seq(features, reuse=True)\n",
    "        if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "            return tf.estimator.EstimatorSpec(mode, predictions=predictions)\n",
    "        y_seq_len = features['y_len']\n",
    "        masks = tf.sequence_mask(y_seq_len, tf.reduce_max(y_seq_len), dtype=tf.float32)\n",
    "        loss_op = tf.contrib.seq2seq.sequence_loss(logits = logits, targets = labels, weights = masks)\n",
    "        params = tf.trainable_variables()\n",
    "        gradients = tf.gradients(loss_op, params)\n",
    "        clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip)\n",
    "        train_op = tf.train.AdamOptimizer().apply_gradients(zip(clipped_gradients, params),\n",
    "                                                            global_step=tf.train.get_global_step())\n",
    "        acc_op = tf.metrics.accuracy(labels=labels, predictions=predictions)\n",
    "        estim_specs = tf.estimator.EstimatorSpec(\n",
    "            mode = mode,\n",
    "            predictions = predictions,\n",
    "            loss = loss_op,\n",
    "            train_op = train_op,\n",
    "            eval_metric_ops = {'accuracy': acc_op})\n",
    "        return estim_specs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 256\n",
    "num_layers = 2\n",
    "embedded_size = 256\n",
    "batch_size = len(text_from)\n",
    "model = Chatbot(size_layer, num_layers, embedded_size, batch_size,\n",
    "                vocabulary_size_from + 4, vocabulary_size_to + 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i.split():\n",
    "            ints.append(dic.get(k,UNK))\n",
    "        X.append(ints)\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return np.array(padded_seqs).astype(np.int32), np.array(seq_lens).astype(np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_x, seq_x = pad_sentence_batch(X, PAD)\n",
    "batch_y, seq_y = pad_sentence_batch(Y, PAD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_fn = tf.estimator.inputs.numpy_input_fn(\n",
    "            x={'x':batch_x, 'x_len':seq_x, 'y':batch_y, 'y_len':seq_y}, y=batch_y,\n",
    "            batch_size=batch_size, num_epochs=100, shuffle=False)\n",
    "model.model.train(input_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
